# -*-Python-*-

import mesh_tensorflow.transformer.transformer_layers

# Gin macro: other bindings can read this value as %d_ff. Nothing in this file
# references it -- presumably a feed-forward layer's hidden size is bound to
# %d_ff in a config that is loaded together with this one (TODO confirm).
d_ff = 512

# Layer list handed to transformer.make_layer_stack when it is called inside
# the `encoder` gin scope. Each `@name` entry is a gin configurable reference
# (the configurable itself, not a literal value). The list holds two entries:
# LocalConvAttnBlock, then DenseReluDense.
# NOTE(review): presumably the stack applies them in list order -- confirm in
# make_layer_stack.
encoder/transformer.make_layer_stack.layers = [
    @transformer_layers.LocalConvAttnBlock,
    @transformer_layers.DenseReluDense,
  ]

# Layer list handed to transformer.make_layer_stack when it is called inside
# the `decoder` gin scope. Each `@name` entry is a gin configurable reference.
# The list holds three entries: LocalConvAttnBlock, EncDecAttention, then
# DenseReluDense.
# NOTE(review): presumably EncDecAttention is where the decoder reads the
# encoder output, and the entries are applied in list order -- confirm in
# transformer_layers / make_layer_stack.
decoder/transformer.make_layer_stack.layers = [
    @transformer_layers.LocalConvAttnBlock,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]

# Unscoped bindings: with no `scope/` prefix they apply to LocalConvAttnBlock
# in every scope, so the encoder and decoder stacks both pick them up.
# %d_model is a macro reference; the macro is not defined in this file and
# must be provided by another gin file or binding loaded with this one.
transformer_layers.LocalConvAttnBlock.output_size = %d_model
transformer_layers.LocalConvAttnBlock.num_unique_depth_filters = 2
# NOTE(review): "dynamic_conv" presumably selects one of several attention
# variants implemented by LocalConvAttnBlock -- see that class for the
# accepted values.
transformer_layers.LocalConvAttnBlock.attention_type = "dynamic_conv"

# Bindings that only take effect inside the `encoder` gin scope. The text
# after `encoder/` is a module-qualified selector for LocalConvAttnBlock.
# The range is symmetric around zero: -4 .. 4.
# NOTE(review): presumably these bound the relative offsets the block may look
# at on either side of the current position -- confirm the exact (inclusive or
# exclusive) semantics in LocalConvAttnBlock.
encoder/transformer.transformer_layers.LocalConvAttnBlock.max_relative_pos = 4
encoder/transformer.transformer_layers.LocalConvAttnBlock.min_relative_pos = -4

# Bindings that only take effect inside the `decoder` gin scope. The text
# after `decoder/` is a module-qualified selector for LocalConvAttnBlock.
# The range is one-sided: -8 .. 0, i.e. no positive offsets.
# NOTE(review): presumably max_relative_pos = 0 keeps decoder positions from
# looking at later positions (causal decoding) -- confirm in
# LocalConvAttnBlock.
decoder/transformer.transformer_layers.LocalConvAttnBlock.max_relative_pos = 0
decoder/transformer.transformer_layers.LocalConvAttnBlock.min_relative_pos = -8

